{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7478e0e2-b7af-4fd4-b44e-ca58e0c31b71",
   "metadata": {},
   "source": [
    "# Getting started with `PatchTST`\n",
    "## Direct forecasting example\n",
    "\n",
    "This notebook demonstrates the usage of a `PatchTST` model for a multivariate time series forecasting task. This notebook has a dependecy on HuggingFace [transformers](https://github.com/huggingface/transformers) repo. For details related to model architecture, refer to the [PatchTST paper](https://arxiv.org/abs/2211.14730)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f63ae353-96df-4380-89f6-1e6cebf684fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard\n",
    "import os\n",
    "import random\n",
    "\n",
    "# Third Party\n",
    "from transformers import (\n",
    "    EarlyStoppingCallback,\n",
    "    PatchTSTConfig,\n",
    "    PatchTSTForPrediction,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    ")\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# First Party\n",
    "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
    "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
    "from tsfm_public.toolkit.util import select_by_index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a826c4f3-1c6c-4088-b6af-f430f45fd380",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set seed for reproducibility\n",
    "SEED = 42\n",
    "torch.manual_seed(SEED)\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e4eb9be-c19f-448f-a4bd-c600e068633f",
   "metadata": {},
   "source": [
    "## Load and prepare datasets\n",
    "\n",
    "In the next cell, please adjust the following parameters to suit your application:\n",
    "- `dataset_path`: path to local .csv file, or web address to a csv file for the data of interest. Data is loaded with pandas, so anything supported by\n",
    "`pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n",
    "- `timestamp_column`: column name containing timestamp information, use None if there is no such column\n",
    "- `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use []\n",
    "- `forecast_columns`: List of columns to be modeled\n",
    "- `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to\n",
    "context_length will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created\n",
    "so that they are contained within a single time series (i.e., a single ID).\n",
    "- `forecast_horizon`: Number of time stamps to forecast in future.\n",
    "- `train_start_index`, `train_end_index`: the start and end indices in the loaded data which delineate the training data.\n",
    "- `valid_start_index`, `valid_end_index`: the start and end indices in the loaded data which delineate the validation data.\n",
    "- `test_start_index`, `test_end_index`: the start and end indices in the loaded data which delineate the test data.\n",
    "- `patch_length`: The patch length for the `PatchTST` model. Recommended to have a value so that `context_length` is divisible by it.\n",
    "- `num_workers`: Number of dataloder workers in pytorch dataloader.\n",
    "- `batch_size`: Batch size. \n",
    "The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the pandas dataframes are converted\n",
    "to the appropriate torch dataset needed for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d4c1e812-f2d6-4ccb-a79c-47879b562d85",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = \"ETTh1\"\n",
    "num_workers = 8  # Reduce this if you have low number of CPU cores\n",
    "batch_size = 32  # Reduce if not enough GPU memory available\n",
    "context_length = 512\n",
    "forecast_horizon = 96\n",
    "patch_length = 12"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "19ca5a76-64d4-4f8d-92f5-67f76c1685fd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading target dataset: ETTh1\n"
     ]
    }
   ],
   "source": [
    "print(f\"Loading target dataset: {dataset}\")\n",
    "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n",
    "timestamp_column = \"date\"\n",
    "id_columns = []\n",
    "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n",
    "train_start_index = None  # None indicates beginning of dataset\n",
    "train_end_index = 12 * 30 * 24\n",
    "\n",
    "# we shift the start of the validation/test period back by context length so that\n",
    "# the first validation/test timestamp is immediately following the training data\n",
    "valid_start_index = 12 * 30 * 24 - context_length\n",
    "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n",
    "\n",
    "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n",
    "test_end_index = 12 * 30 * 24 + 8 * 30 * 24"
   ]
  },
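  {
   "cell_type": "markdown",
   "id": "5d0c7b1e-2f4a-4c3b-8e19-6a7b9c0d1e01",
   "metadata": {},
   "source": [
    "The split indices above can be checked with a small standalone sketch. ETTh1 is sampled hourly, so `30 * 24` approximates one month: 12 months for training, then 4 months each for validation and test. The helper `split_bounds` is hypothetical (it is not part of `tsfm_public`) and only reproduces the arithmetic.\n",
    "\n",
    "```python\n",
    "# Illustrative sketch only; split_bounds is a hypothetical helper.\n",
    "HOURS_PER_MONTH = 30 * 24  # hourly data, with a month approximated as 30 days\n",
    "\n",
    "\n",
    "def split_bounds(context_length, train_months=12, valid_months=4, test_months=4):\n",
    "    # Return (start, end) row indices for the train, validation, and test splits.\n",
    "    train_end = train_months * HOURS_PER_MONTH\n",
    "    valid_end = train_end + valid_months * HOURS_PER_MONTH\n",
    "    test_end = valid_end + test_months * HOURS_PER_MONTH\n",
    "    return {\n",
    "        'train': (0, train_end),\n",
    "        # Start context_length rows early so that the history of the first\n",
    "        # window comes from the preceding split.\n",
    "        'valid': (train_end - context_length, valid_end),\n",
    "        'test': (valid_end - context_length, test_end),\n",
    "    }\n",
    "\n",
    "\n",
    "bounds = split_bounds(context_length=512)\n",
    "print(bounds)\n",
    "```\n",
    "\n",
    "Because each evaluation split starts `context_length` rows early, its first window draws its history from the preceding split and its first forecast target is the first row of the split itself."
   ]
  },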
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0e856ce6-11e9-4c8c-9aeb-3e4a8b05f3eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TimeSeriesPreprocessor {\n",
       "  \"context_length\": 64,\n",
       "  \"feature_extractor_type\": \"TimeSeriesPreprocessor\",\n",
       "  \"id_columns\": [],\n",
       "  \"input_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"output_columns\": [\n",
       "    \"HUFL\",\n",
       "    \"HULL\",\n",
       "    \"MUFL\",\n",
       "    \"MULL\",\n",
       "    \"LUFL\",\n",
       "    \"LULL\",\n",
       "    \"OT\"\n",
       "  ],\n",
       "  \"prediction_length\": null,\n",
       "  \"processor_class\": \"TimeSeriesPreprocessor\",\n",
       "  \"scale_outputs\": false,\n",
       "  \"scaler_dict\": {\n",
       "    \"0\": {\n",
       "      \"copy\": true,\n",
       "      \"feature_names_in_\": [\n",
       "        \"HUFL\",\n",
       "        \"HULL\",\n",
       "        \"MUFL\",\n",
       "        \"MULL\",\n",
       "        \"LUFL\",\n",
       "        \"LULL\",\n",
       "        \"OT\"\n",
       "      ],\n",
       "      \"mean_\": [\n",
       "        7.937742245659508,\n",
       "        2.0210386567335163,\n",
       "        5.079770601157927,\n",
       "        0.7461858799957015,\n",
       "        2.781762386375555,\n",
       "        0.7884531235540096,\n",
       "        17.1282616982271\n",
       "      ],\n",
       "      \"n_features_in_\": 7,\n",
       "      \"n_samples_seen_\": 8640,\n",
       "      \"scale_\": [\n",
       "        5.812749409143771,\n",
       "        2.0901046504076,\n",
       "        5.518793579036245,\n",
       "        1.9263792741329822,\n",
       "        1.0235226594952194,\n",
       "        0.6302366362251923,\n",
       "        9.176491024944335\n",
       "      ],\n",
       "      \"var_\": [\n",
       "        33.78805569350125,\n",
       "        4.368537449655475,\n",
       "        30.457082568011693,\n",
       "        3.710937107809115,\n",
       "        1.0475986345001667,\n",
       "        0.39719821764044544,\n",
       "        84.20798753088393\n",
       "      ],\n",
       "      \"with_mean\": true,\n",
       "      \"with_std\": true\n",
       "    }\n",
       "  },\n",
       "  \"scaling\": true,\n",
       "  \"time_series_task\": \"forecasting\",\n",
       "  \"timestamp_column\": \"date\"\n",
       "}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(\n",
    "    dataset_path,\n",
    "    parse_dates=[timestamp_column],\n",
    ")\n",
    "\n",
    "train_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=train_start_index,\n",
    "    end_index=train_end_index,\n",
    ")\n",
    "valid_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=valid_start_index,\n",
    "    end_index=valid_end_index,\n",
    ")\n",
    "test_data = select_by_index(\n",
    "    data,\n",
    "    id_columns=id_columns,\n",
    "    start_index=test_start_index,\n",
    "    end_index=test_end_index,\n",
    ")\n",
    "\n",
    "tsp = TimeSeriesPreprocessor(\n",
    "    timestamp_column=timestamp_column,\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    scaling=True,\n",
    ")\n",
    "tsp.train(train_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "678d849d-41fc-450d-a855-1dde27179b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(train_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "valid_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(valid_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")\n",
    "test_dataset = ForecastDFDataset(\n",
    "    tsp.preprocess(test_data),\n",
    "    id_columns=id_columns,\n",
    "    input_columns=forecast_columns,\n",
    "    output_columns=forecast_columns,\n",
    "    context_length=context_length,\n",
    "    prediction_length=forecast_horizon,\n",
    ")"
   ]
  },
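  {
   "cell_type": "markdown",
   "id": "8a3e6f42-1b7d-4e05-9c2a-4f5d6e7a8b02",
   "metadata": {},
   "source": [
    "A minimal NumPy sketch of the windowing that a forecasting dataset performs, assuming a stride of one time step between consecutive windows. `make_windows` is a hypothetical helper, not the `ForecastDFDataset` implementation, and the dummy array only mimics the size of the test split (rows 11008 to 14400, 7 channels).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "\n",
    "\n",
    "def make_windows(series, context_length, prediction_length):\n",
    "    # series has shape (num_timesteps, num_channels)\n",
    "    total = context_length + prediction_length\n",
    "    # sliding_window_view appends the window axis last; move it before the channels\n",
    "    windows = sliding_window_view(series, total, axis=0).transpose(0, 2, 1)\n",
    "    past = windows[:, :context_length, :]  # model input\n",
    "    future = windows[:, context_length:, :]  # forecast target\n",
    "    return past, future\n",
    "\n",
    "\n",
    "# Dummy data with the same number of rows as the test split and 7 channels\n",
    "series = np.zeros((14400 - 11008, 7), dtype=np.float32)\n",
    "past, future = make_windows(series, context_length=512, prediction_length=96)\n",
    "print(past.shape, future.shape)\n",
    "```\n",
    "\n",
    "Under that assumption the test split yields 3392 - (512 + 96) + 1 = 2785 windows, which agrees with the prediction shape reported at the end of this notebook."
   ]
  },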
  {
   "cell_type": "markdown",
   "id": "ae939491-8813-44c9-bc3d-d1c6a5764cd4",
   "metadata": {},
   "source": [
    "## Testing with a `PatchTST` model that was trained on the training part of the `ETTh1` data\n",
    "\n",
    "A pre-trained model (on `ETTh1` data) is available at [ibm/patchtst-etth1-forecasting](https://huggingface.co/ibm/patchtst-etth1-forecasting)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b839a219-14b9-4445-aa37-f0e466a35e62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading pretrained model\n",
      "Done\n"
     ]
    }
   ],
   "source": [
    "print(\"Loading pretrained model\")\n",
    "inference_forecast_model = PatchTSTForPrediction.from_pretrained(\n",
    "    \"ibm/patchtst-etth1-forecasting\"\n",
    ")\n",
    "print(\"Done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9b17c69f-f44d-4ccc-b8d4-5dd34cd1b43e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CODECARBON : No CPU tracking mode found. Falling back on CPU constant mode.\n",
      "CODECARBON : Failed to match CPU TDP constant. Falling back on a global constant.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Doing testing on Etth1/test data\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='349' max='349' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [349/349 00:02]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.38813790678977966, 'eval_runtime': 2.9498, 'eval_samples_per_second': 944.116, 'eval_steps_per_second': 118.311}\n"
     ]
    }
   ],
   "source": [
    "train_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtst/direct/train/output/\",\n",
    "    label_names=[\"future_values\"],\n",
    ")\n",
    "\n",
    "inference_forecast_trainer = Trainer(\n",
    "    model=inference_forecast_model,\n",
    "    args=train_args,\n",
    ")\n",
    "\n",
    "print(\"\\n\\nDoing testing on Etth1/test data\")\n",
    "result = inference_forecast_trainer.evaluate(test_dataset)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19456329-1293-45bf-99c7-e5ccf0534846",
   "metadata": {},
   "source": [
    "## If we want to train from scratch\n",
    "\n",
    "The settings below control the different components in the PatchTST model.\n",
    "  - `num_input_channels`: the number of input channels (or dimensions) in the time series data. This is\n",
    "    automatically set to the number for forecast columns.\n",
    "  - `context_length`: As described above, the amount of historical data used as input to the model.\n",
    "  - `patch_length`: The length of the patches extracted from the context window (of length `context_length`).\n",
    "  - `patch_stride`: The stride used when extracting patches from the context window.\n",
    "  - `d_model`: Dimension of the transformer layers.\n",
    "  - `num_attention_heads`: The number of attention heads for each attention layer in the Transformer encoder.\n",
    "  - `num_hidden_layers`: The number of encoder layers.\n",
    "  - `ffn_dim`: Dimension of the intermediate (often referred to as feed-forward) layer in the encoder.\n",
    "  - `dropout`: Dropout probability for all fully connected layers in the encoder.\n",
    "  - `head_dropout`: Dropout probability used in the head of the model.\n",
    "  - `pooling_type`: Pooling of the embedding. `\"mean\"`, `\"max\"` and `None` are supported.\n",
    "  - `channel_attention`: Activate channel attention block in the Transformer to allow channels to attend each other.\n",
    "  - `scaling`: Whether to scale the input targets via \"mean\" scaler, \"std\" scaler or no scaler if `None`. If `True`, the\n",
    "    scaler is set to `\"mean\"`.\n",
    "  - `loss`: The loss function for the model corresponding to the `distribution_output` head. For parametric\n",
    "    distributions it is the negative log likelihood (`\"nll\"`) and for point estimates it is the mean squared\n",
    "    error `\"mse\"`.\n",
    "  - `pre_norm`: Normalization is applied before self-attention if pre_norm is set to `True`. Otherwise, normalization is\n",
    "    applied after residual block.\n",
    "  - `norm_type`: Normalization at each Transformer layer. Can be `\"BatchNorm\"` or `\"LayerNorm\"`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "226b904e-1ab2-478b-98b4-ce99bc23f1c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = PatchTSTConfig(\n",
    "    do_mask_input=False,\n",
    "    context_length=context_length,\n",
    "    patch_length=patch_length,\n",
    "    num_input_channels=len(forecast_columns),\n",
    "    patch_stride=patch_length,\n",
    "    prediction_length=forecast_horizon,\n",
    "    d_model=128,\n",
    "    num_attention_heads=16,\n",
    "    num_hidden_layers=3,\n",
    "    ffn_dim=512,\n",
    "    dropout=0.2,\n",
    "    head_dropout=0.2,\n",
    "    pooling_type=None,\n",
    "    channel_attention=False,\n",
    "    scaling=\"std\",\n",
    "    loss=\"mse\",\n",
    "    pre_norm=True,\n",
    "    norm_type=\"batchnorm\",\n",
    ")\n",
    "model = PatchTSTForPrediction(config=config)"
   ]
  },
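  {
   "cell_type": "markdown",
   "id": "b2f4d8c6-7e19-4a3b-8d50-1c2e3f4a5b03",
   "metadata": {},
   "source": [
    "The relationship between `context_length`, `patch_length`, and `patch_stride` can be sketched with NumPy. With the values used here, 512 is not an exact multiple of 12, so non-overlapping patches cannot cover the whole context window. `count_patches` and `extract_patches` are hypothetical helpers, not the `transformers` implementation, and which end of the window is left uncovered is an implementation detail this sketch does not try to reproduce.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from numpy.lib.stride_tricks import sliding_window_view\n",
    "\n",
    "\n",
    "def count_patches(context_length, patch_length, patch_stride):\n",
    "    # Number of complete patches that fit into the context window\n",
    "    return (context_length - patch_length) // patch_stride + 1\n",
    "\n",
    "\n",
    "def extract_patches(x, patch_length, patch_stride):\n",
    "    # x has shape (context_length,); keep every patch_stride-th window\n",
    "    return sliding_window_view(x, patch_length)[::patch_stride]\n",
    "\n",
    "\n",
    "num_patches = count_patches(context_length=512, patch_length=12, patch_stride=12)\n",
    "covered = 12 + 12 * (num_patches - 1)  # time steps covered by the patches\n",
    "patches = extract_patches(np.arange(512), patch_length=12, patch_stride=12)\n",
    "print(num_patches, covered, patches.shape)\n",
    "```\n",
    "\n",
    "This gives 42 patches covering 504 of the 512 time steps. Choosing a `patch_length` that divides `context_length` evenly avoids the leftover steps."
   ]
  },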
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "27812e8c-c0f6-45e3-a075-310929329460",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "CODECARBON : No CPU tracking mode found. Falling back on CPU constant mode.\n",
      "CODECARBON : Failed to match CPU TDP constant. Falling back on a global constant.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Doing forecasting training on Etth1/train\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1512' max='25200' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 1512/25200 03:44 < 58:35, 6.74 it/s, Epoch 6/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.387900</td>\n",
       "      <td>0.678949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.337100</td>\n",
       "      <td>0.705020</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.317400</td>\n",
       "      <td>0.731043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.304200</td>\n",
       "      <td>0.783324</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.291700</td>\n",
       "      <td>0.816706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.278900</td>\n",
       "      <td>0.822570</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1512, training_loss=0.3195255925415685, metrics={'train_runtime': 223.643, 'train_samples_per_second': 3591.886, 'train_steps_per_second': 112.68, 'total_flos': 1158800617046016.0, 'train_loss': 0.3195255925415685, 'epoch': 6.0})"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_args = TrainingArguments(\n",
    "    output_dir=\"./checkpoint/patchtst/direct/train/output/\",\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=0.0001,\n",
    "    num_train_epochs=100,\n",
    "    do_eval=True,\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    dataloader_num_workers=1,  # num_workers,\n",
    "    save_strategy=\"epoch\",\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=3,\n",
    "    logging_dir=\"./checkpoint/patchtst/direct/train/logs/\",  # Make sure to specify a logging directory\n",
    "    load_best_model_at_end=True,  # Load the best model when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric to monitor for early stopping\n",
    "    greater_is_better=False,  # For loss\n",
    "    label_names=[\"future_values\"],\n",
    ")\n",
    "\n",
    "# Create a new early stopping callback with faster convergence properties\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=5,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=train_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    ")\n",
    "\n",
    "print(\"\\n\\nDoing forecasting training on Etth1/train\")\n",
    "trainer.train()"
   ]
  },
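  {
   "cell_type": "markdown",
   "id": "e7c9a1b3-4d62-4f80-9b17-2a3b4c5d6e04",
   "metadata": {},
   "source": [
    "Training stopped after epoch 6 of 100 because the validation loss never improved on its epoch-1 value. The plain-Python sketch below replays the validation losses from the log above through a simplified patience rule. `simulate_early_stopping` is a hypothetical helper that approximates, but does not reproduce, the internals of `EarlyStoppingCallback`.\n",
    "\n",
    "```python\n",
    "def simulate_early_stopping(eval_losses, patience, threshold):\n",
    "    # Returns (stop_epoch, best_epoch), both 1-based.\n",
    "    # stop_epoch is None if the patience is never exhausted.\n",
    "    best_loss = None\n",
    "    best_epoch = None\n",
    "    bad_evals = 0\n",
    "    for epoch, loss in enumerate(eval_losses, start=1):\n",
    "        # An evaluation counts as an improvement only if it beats the best\n",
    "        # loss so far by more than the threshold.\n",
    "        improved = best_loss is None or (best_loss - loss > threshold)\n",
    "        if improved:\n",
    "            best_loss, best_epoch, bad_evals = loss, epoch, 0\n",
    "        else:\n",
    "            bad_evals += 1\n",
    "        if bad_evals >= patience:\n",
    "            return epoch, best_epoch\n",
    "    return None, best_epoch\n",
    "\n",
    "\n",
    "# Validation losses copied from the training log above\n",
    "val_losses = [0.678949, 0.705020, 0.731043, 0.783324, 0.816706, 0.822570]\n",
    "stop_epoch, best_epoch = simulate_early_stopping(val_losses, patience=5, threshold=0.001)\n",
    "print(stop_epoch, best_epoch)\n",
    "```\n",
    "\n",
    "Under this rule the run stops at epoch 6 with epoch 1 as the best epoch. Since `load_best_model_at_end=True`, the model evaluated afterwards should be the epoch-1 checkpoint rather than the final one."
   ]
  },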
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e06f3931-6b5f-450e-b17d-360ae3984e67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.3788062334060669,\n",
       " 'eval_runtime': 8.817,\n",
       " 'eval_samples_per_second': 315.868,\n",
       " 'eval_steps_per_second': 9.981,\n",
       " 'epoch': 6.0}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainer.evaluate(test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b2e8f0a-8367-4c10-84bd-a72e8f21ccc4",
   "metadata": {},
   "source": [
    "#### Sanity check: Compute number of forecasting channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a1cdf078-10e8-4788-b7e4-3d1640ab8f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = trainer.predict(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "4db1b83f-7381-4fe6-97e4-42ee210ab71d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2785, 96, 7)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.predictions[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6e5bf8-85a3-4c8f-bcb9-8533c8c13393",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 96 step forecasting horizon, for 7 channels"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
